{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Adversarial Training\n",
    "\n",
    "This notebook builds an adversarial training strategy that hardens a neural network against a digital FGSM (Fast Gradient Sign Method) attack.\n",
    "\n",
    "### Assumptions\n",
    "- The dataset wrangling has already been completed (and is provided here)\n",
    "- The adversarial attack (FGSM) has already been completed\n",
    "- The outer loop of training has already been completed; we only override the single-epoch step in a subclass\n",
    "- The plotting code has already been completed\n",
    "\n",
    "### Components Recreated in Tutorial\n",
    "- Adversarial training constrained by a signal-to-perturbation power ratio (`spr`) and by the fraction of each batch (`k`) that is replaced with adversarial examples each epoch.\n",
    "\n",
    "### See Also\n",
    "The code in this tutorial is a stripped-down version of the code in ``rfml.nn.train.adversarial`` that simplifies discussion.  For further detail, browse the source files directly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install the library code\n",
    "!pip install git+https://github.com/brysef/rfml.git@1.0.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plotting Includes\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# External Includes\n",
    "import numpy as np\n",
    "\n",
    "import torch\n",
    "from torch.autograd import Variable\n",
    "from torch.nn import CrossEntropyLoss\n",
    "from torch.optim import Adam\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "# Internal Includes\n",
    "from rfml.attack import fgsm\n",
    "\n",
    "from rfml.data import Dataset, Encoder\n",
    "from rfml.data import build_dataset\n",
    "\n",
    "from rfml.nbutils import plot_acc_vs_spr, plot_acc_vs_snr\n",
    "\n",
    "from rfml.nn.eval import compute_accuracy, compute_accuracy_on_cross_sections\n",
    "from rfml.nn.model import build_model, Model\n",
    "from rfml.nn.train import StandardTrainingStrategy, PrintingTrainingListener"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gpu = True       # Set to True to use a GPU for training\n",
    "fig_dir = None   # Set to a file path if you'd like to save the plots generated\n",
    "data_path = None # Set to a file path if you've downloaded RML2016.10A locally"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Adversarial Training of a Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load dataset and a DNN model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train, val, test, le = build_dataset(\"RML2016.10a\", path=data_path)\n",
    "# as_numpy returns x,y and x is shape BxCxIQxN\n",
    "input_samples = val.as_numpy(le=le)[0].shape[3]\n",
    "model = build_model(model_name=\"CNN\", input_samples=input_samples, n_classes=len(le))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Creating our own adversarial trainer\n",
    "\n",
    "One of the simplest and most effective ways to harden deep learning models against adversarial attacks is to \"show\" them adversarial examples during training, a process known as adversarial training.\n",
    "\n",
    "Here, we recreate the adversarial training from [Kurakin et al.], which uses the FGSM attack from [Goodfellow et al.] to augment the training examples with adversarial examples.  Note that, in the context of computer vision, this coupling of adversarial attack and training was found to produce misleading robustness: the model was actually learning to obfuscate the gradient used as a \"signal\" to create the adversarial example rather than becoming robust to the attack [Tramer et al.].  [Tramer et al.] then extended the methodology to Ensemble Adversarial Training.  This notebook only demonstrates the adversarial training proposed in [Kurakin et al.] because it is easier to keep self-contained in a notebook for demonstration.\n",
    "\n",
    "Also note that adversarial training was applied in the context of RF in [Kokalj-Filipovic and Miller].\n",
    "\n",
    "#### Goodfellow et al.\n",
    "Goodfellow, I., Shlens, J., and Szegedy, C. (2015). Explaining and harnessing adversarial examples. In Int. Conf. on Learning Representations.\n",
    "\n",
    "#### Kurakin et al.\n",
    "Kurakin, A., Goodfellow, I. J., and Bengio, S. (2016). Adversarial machine learning at scale. CoRR, abs/1611.01236.\n",
    "\n",
    "#### Tramer et al.\n",
    "Tramer, F., Kurakin, A., Papernot, N., Boneh, D., and McDaniel, P. D. (2017). Ensemble adversarial training: Attacks and defenses. CoRR, abs/1705.07204.\n",
    "\n",
    "#### Kokalj-Filipovic and Miller\n",
    "Kokalj-Filipovic, S. and Miller, R. (2019). Adversarial examples in RF deep learning: Detection of the attack and its physical robustness. CoRR, abs/1902.06044."
   ]
  },
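  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the FGSM step less of a black box, the cell below gives a minimal sketch of a one-step, untargeted FGSM perturbation in plain PyTorch.  It is an illustration under stated assumptions, not the implementation behind `rfml.attack.fgsm`: the helper name `sketch_fgsm`, the toy linear classifier, and the way the step size is derived from a signal-to-perturbation ratio in dB (mean perturbation power equals mean signal power divided by `10 ** (spr_db / 10)`) are all assumptions made for this example.  In particular, the sketch ignores the `sps` scaling that the library applies."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "\n",
    "def sketch_fgsm(x, y, net, spr_db):\n",
    "    # Hypothetical helper, not part of rfml: one-step untargeted FGSM.\n",
    "    x = x.clone().detach().requires_grad_(True)\n",
    "    loss = nn.functional.cross_entropy(net(x), y)\n",
    "    # Gradient of the loss with respect to the input only; this call does not\n",
    "    # accumulate anything into the parameter gradients.\n",
    "    (grad,) = torch.autograd.grad(loss, x)\n",
    "    direction = grad.sign()\n",
    "\n",
    "    # Assumed scaling: choose a per-example step so that\n",
    "    # signal power / perturbation power == 10 ** (spr_db / 10).\n",
    "    dims = tuple(range(1, x.dim()))\n",
    "    signal_power = x.detach().pow(2).mean(dim=dims, keepdim=True)\n",
    "    direction_power = direction.pow(2).mean(dim=dims, keepdim=True).clamp_min(1e-12)\n",
    "    step = torch.sqrt(signal_power / (10.0 ** (spr_db / 10.0)) / direction_power)\n",
    "    return (x + step * direction).detach()\n",
    "\n",
    "\n",
    "# Toy usage on random data shaped BxCxIQxN.  Names are prefixed with sketch_\n",
    "# so that they do not clobber the notebook's own variables.\n",
    "torch.manual_seed(0)\n",
    "sketch_net = nn.Sequential(nn.Flatten(), nn.Linear(2 * 16, 4))\n",
    "sketch_x = torch.randn(8, 1, 2, 16)\n",
    "sketch_y = torch.randint(0, 4, (8,))\n",
    "sketch_adv = sketch_fgsm(sketch_x, sketch_y, sketch_net, spr_db=10.0)\n",
    "\n",
    "# Achieved signal-to-perturbation ratio per example, in dB\n",
    "sketch_p = (sketch_adv - sketch_x).pow(2).mean(dim=(1, 2, 3))\n",
    "sketch_s = sketch_x.pow(2).mean(dim=(1, 2, 3))\n",
    "print(10.0 * torch.log10(sketch_s / sketch_p))"
   ]
  },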
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyAdversarialTrainingStrategy(StandardTrainingStrategy):\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        lr: float = 10e-4,\n",
    "        max_epochs: int = 50,\n",
    "        patience: int = 5,\n",
    "        batch_size: int = 512,\n",
    "        gpu: bool = True,\n",
    "        k: float = 0.05,\n",
    "        spr: float = 10.0,\n",
    "    ):\n",
    "        super().__init__(\n",
    "            lr=lr,\n",
    "            max_epochs=max_epochs,\n",
    "            patience=patience,\n",
    "            batch_size=batch_size,\n",
    "            gpu=gpu,\n",
    "        )\n",
    "        self.k = k\n",
    "        self.spr = spr\n",
    "\n",
    "        # The exact value of the sps shouldn't actually matter.  It's simply used for\n",
    "        # an intermediate scaling of the example before applying the adversarial\n",
    "        # perturbation with FGSM.  This assumption that it shouldn't matter is based\n",
    "        # upon the expectation the model does the normalization as the first \"layer\"\n",
    "        # in its network.\n",
    "        self.sps = 8\n",
    "\n",
    "    def _train_one_epoch(\n",
    "        self, model: Model, data: DataLoader, loss_fn: CrossEntropyLoss, optimizer: Adam\n",
    "    ) -> float:\n",
    "        total_loss = 0.0\n",
    "        # Switch the model to training mode so that dropout is active, batch\n",
    "        # normalization updates its running statistics, etc.\n",
    "        model.train()\n",
    "\n",
    "        for i, batch in enumerate(data):\n",
    "            x, y = batch\n",
    "\n",
    "            # Perform adversarial augmentation in the training loop using FGSM\n",
    "            x = self._adversarial_augmentation(x=x, y=y, model=model)\n",
    "\n",
    "            # Push data to GPU\n",
    "            if self.gpu:\n",
    "                x = Variable(x.cuda())\n",
    "                y = Variable(y.cuda())\n",
    "            else:\n",
    "                x = Variable(x)\n",
    "                y = Variable(y)\n",
    "\n",
    "            # Forward pass of prediction -- while some are adversarial\n",
    "            outputs = model(x)\n",
    "\n",
    "            # Compute the loss, zero out the parameter gradients (they accumulate\n",
    "            # across backward passes), compute gradients (backward), update weights\n",
    "            loss = loss_fn(outputs, y)\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            total_loss += loss.item()\n",
    "\n",
    "        mean_loss = total_loss / (i + 1.0)\n",
    "        return mean_loss\n",
    "\n",
    "    def _adversarial_augmentation(\n",
    "        self, x: torch.Tensor, y: torch.Tensor, model: Model\n",
    "    ) -> torch.Tensor:\n",
    "        # Rely on the fact that the DataLoader shuffles -- therefore can just take the\n",
    "        # first *n* examples and perform adversarial augmentation on it and it will be\n",
    "        # a random selection.\n",
    "        n_adversarial = int(self.k * x.shape[0])\n",
    "        if n_adversarial == 0:\n",
    "            return x\n",
    "\n",
    "        x[0:n_adversarial, ::] = fgsm(\n",
    "            x=x[0:n_adversarial, ::],\n",
    "            y=y[0:n_adversarial],\n",
    "            net=model,\n",
    "            spr=self.spr,\n",
    "            sps=self.sps,\n",
    "        )\n",
    "\n",
    "        return x"
   ]
  },
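  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The selection rule in `_adversarial_augmentation` can be checked in isolation.  The sketch below uses a hypothetical helper, `sketch_augment_fraction`, and a stand-in perturbation (adding a constant) instead of FGSM; both are assumptions for illustration only.  It shows that the first `int(k * batch_size)` examples of a batch are altered, and that a small batch combined with a small `k` is left untouched because the count truncates to zero.  Unlike the method above, the sketch works on a copy rather than modifying `x` in place."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "def sketch_augment_fraction(x, k, perturb):\n",
    "    # Hypothetical helper: perturb the first int(k * B) examples of a batch.\n",
    "    n_adversarial = int(k * x.shape[0])\n",
    "    if n_adversarial == 0:\n",
    "        return x\n",
    "    out = x.clone()\n",
    "    out[:n_adversarial] = perturb(out[:n_adversarial])\n",
    "    return out\n",
    "\n",
    "\n",
    "sketch_batch = torch.zeros(8, 1, 2, 16)\n",
    "sketch_quarter = sketch_augment_fraction(sketch_batch, k=0.25, perturb=lambda t: t + 1.0)\n",
    "sketch_small_k = sketch_augment_fraction(sketch_batch, k=0.05, perturb=lambda t: t + 1.0)\n",
    "\n",
    "# Count how many examples changed in each case\n",
    "print((sketch_quarter != sketch_batch).flatten(1).any(dim=1).sum().item())  # 2\n",
    "print((sketch_small_k != sketch_batch).flatten(1).any(dim=1).sum().item())  # 0"
   ]
  },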
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = MyAdversarialTrainingStrategy(max_epochs=10,\n",
    "                                        patience=3,\n",
    "                                        gpu=gpu,\n",
    "                                        k=0.25,\n",
    "                                        spr=10)\n",
    "trainer.register_listener(PrintingTrainingListener())\n",
    "trainer(model=model, training=train, validation=val, le=le)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Testing the Adversarially Trained Model on normal data\n",
    "This ensures that we haven't completely sacrificed our performance in the baseline case."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "acc = compute_accuracy(model=model, data=test, le=le)\n",
    "print(\"Overall Testing Accuracy: {:.4f}\".format(acc))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "acc_vs_snr, snr = compute_accuracy_on_cross_sections(model=model,\n",
    "                                                     data=test,\n",
    "                                                     le=le,\n",
    "                                                     column=\"SNR\")\n",
    "\n",
    "title = \"Accuracy vs SNR of {model_name} on {dataset_name}\".format(model_name=\"CNN\", dataset_name=\"RML2016.10A\")\n",
    "fig = plot_acc_vs_snr(acc_vs_snr=acc_vs_snr, snr=snr, title=title)\n",
    "if fig_dir is not None:\n",
    "    file_path = \"{fig_dir}/hardened_acc_vs_snr.pdf\".format(fig_dir=fig_dir)\n",
    "    print(\"Saving Figure -> {file_path}\".format(file_path=file_path))\n",
    "    fig.savefig(file_path, format=\"pdf\", transparent=True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Verifying the Model has been Hardened\n",
    "Attack the model again with FGSM, assuming direct access to the classifier, to see whether its accuracy under attack has improved over our prior attempt at evading signal classification."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mask = test.df[\"SNR\"] >= 18\n",
    "dl = DataLoader(test.as_torch(le=le, mask=mask), shuffle=True, batch_size=512)"
   ]
  },
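  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The evaluation loop below sweeps the SPR from 50 dB down to 0 dB.  As a quick sanity check on what those values mean, this sketch lists the sweep and converts each value from dB to a linear power ratio, assuming the usual decibel definition `ratio = 10 ** (spr_db / 10)`.  The variable names are illustrative only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Same grid as the evaluation loop: 26 points, evenly spaced from 50 dB to 0 dB\n",
    "sketch_sprs_db = np.linspace(50.0, 0.0, num=26)\n",
    "# Linear ratio of signal power to perturbation power at each point\n",
    "sketch_ratio = 10.0 ** (sketch_sprs_db / 10.0)\n",
    "\n",
    "# Perturbation power as a fraction of the signal power at the start of the\n",
    "# sweep, at 10 dB, and at the end of the sweep\n",
    "for sketch_db, sketch_r in zip(sketch_sprs_db[[0, 20, 25]], sketch_ratio[[0, 20, 25]]):\n",
    "    print(sketch_db, 1.0 / sketch_r)"
   ]
  },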
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ensure that the model is in \"evaluation\" mode\n",
    "# -- Therefore batch normalization uses its running statistics and dropout is not performed\n",
    "# -- Note: Forgetting this is the cause of a lot of bugs\n",
    "model.eval()\n",
    "\n",
    "acc_vs_spr = list()\n",
    "sprs = list()\n",
    "\n",
    "for spr in np.linspace(50.0, 0.0, num=26):\n",
    "    right = 0\n",
    "    total = 0\n",
    "    \n",
    "    for x, y in dl:\n",
    "        adv_x = fgsm(x, y, spr=spr, input_size=input_samples, sps=8, net=model)\n",
    "\n",
    "        predictions = model.predict(adv_x)\n",
    "        right += (predictions == y).sum().item()\n",
    "        total += len(y)\n",
    "\n",
    "    acc = float(right) / total\n",
    "    acc_vs_spr.append(acc)\n",
    "    sprs.append(spr)\n",
    "\n",
    "fig = plot_acc_vs_spr(acc_vs_spr=acc_vs_spr,\n",
    "                      spr=sprs,\n",
    "                      title=\"Performance of a Digital FGSM Attack after Hardening\"\n",
    "                     )\n",
    "if fig_dir is not None:\n",
    "    file_path = \"{fig_dir}/hardened_direct_access_fgsm.pdf\".format(fig_dir=fig_dir)\n",
    "    print(\"Saving Figure -> {file_path}\".format(file_path=file_path))\n",
    "    fig.savefig(file_path, format=\"pdf\", transparent=True)\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
